{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp common._model_checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Checks for models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This file provides a set of unit tests for all models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "import pandas as pd\n",
    "import neuralforecast.losses.pytorch as losses\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic, generate_series"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "seed = 0\n",
    "test_size = 14\n",
    "FREQ = \"D\"\n",
    "\n",
    "# 1 series, no exogenous\n",
    "N_SERIES_1 = 1\n",
    "df = generate_series(n_series=N_SERIES_1, seed=seed, freq=FREQ, equal_ends=True)\n",
    "max_ds = df.ds.max() - pd.Timedelta(test_size, FREQ)\n",
    "Y_TRAIN_DF_1 = df[df.ds < max_ds]\n",
    "Y_TEST_DF_1 = df[df.ds >= max_ds]\n",
    "\n",
    "# 5 series, no exogenous\n",
    "N_SERIES_2 = 5\n",
    "df = generate_series(n_series=N_SERIES_2, seed=seed, freq=FREQ, equal_ends=True)\n",
    "max_ds = df.ds.max() - pd.Timedelta(test_size, FREQ)\n",
    "Y_TRAIN_DF_2 = df[df.ds < max_ds]\n",
    "Y_TEST_DF_2 = df[df.ds >= max_ds]\n",
    "\n",
    "# 1 series, with static and temporal exogenous\n",
    "N_SERIES_3 = 1\n",
    "df, STATIC_3 = generate_series(n_series=N_SERIES_3, n_static_features=2, \n",
    "                     n_temporal_features=2, seed=seed, freq=FREQ, equal_ends=True)\n",
    "max_ds = df.ds.max() - pd.Timedelta(test_size, FREQ)\n",
    "Y_TRAIN_DF_3 = df[df.ds < max_ds]\n",
    "Y_TEST_DF_3 = df[df.ds >= max_ds]\n",
    "\n",
    "# 5 series, with static and temporal exogenous\n",
    "N_SERIES_4 = 5\n",
    "df, STATIC_4 = generate_series(n_series=N_SERIES_4, n_static_features=2, \n",
    "                     n_temporal_features=2, seed=seed, freq=FREQ, equal_ends=True)\n",
    "max_ds = df.ds.max() - pd.Timedelta(test_size, FREQ)\n",
    "Y_TRAIN_DF_4 = df[df.ds < max_ds]\n",
    "Y_TEST_DF_4 = df[df.ds >= max_ds]\n",
    "\n",
    "# Generic test for a given config for a model\n",
    "def _run_model_tests(model_class, config):\n",
    "    if model_class.RECURRENT:\n",
    "        config[\"inference_input_size\"] = config[\"input_size\"]\n",
    "\n",
    "    # DF_1\n",
    "    if model_class.MULTIVARIATE:\n",
    "        config[\"n_series\"] = N_SERIES_1\n",
    "    if isinstance(config[\"loss\"], losses.relMSE):\n",
    "        config[\"loss\"].y_train = Y_TRAIN_DF_1[\"y\"].values   \n",
    "    if isinstance(config[\"valid_loss\"], losses.relMSE):\n",
    "        config[\"valid_loss\"].y_train = Y_TRAIN_DF_1[\"y\"].values   \n",
    "\n",
    "    model = model_class(**config)\n",
    "    fcst = NeuralForecast(models=[model], freq=FREQ)\n",
    "    fcst.fit(df=Y_TRAIN_DF_1, val_size=24)\n",
    "    _ = fcst.predict(futr_df=Y_TEST_DF_1)\n",
    "    # DF_2\n",
    "    if model_class.MULTIVARIATE:\n",
    "        config[\"n_series\"] = N_SERIES_2\n",
    "    if isinstance(config[\"loss\"], losses.relMSE):\n",
    "        config[\"loss\"].y_train = Y_TRAIN_DF_2[\"y\"].values   \n",
    "    if isinstance(config[\"valid_loss\"], losses.relMSE):\n",
    "        config[\"valid_loss\"].y_train = Y_TRAIN_DF_2[\"y\"].values\n",
    "    model = model_class(**config)\n",
    "    fcst = NeuralForecast(models=[model], freq=FREQ)\n",
    "    fcst.fit(df=Y_TRAIN_DF_2, val_size=24)\n",
    "    _ = fcst.predict(futr_df=Y_TEST_DF_2)\n",
    "\n",
    "    if model.EXOGENOUS_STAT and model.EXOGENOUS_FUTR:\n",
    "        # DF_3\n",
    "        if model_class.MULTIVARIATE:\n",
    "            config[\"n_series\"] = N_SERIES_3\n",
    "        if isinstance(config[\"loss\"], losses.relMSE):\n",
    "            config[\"loss\"].y_train = Y_TRAIN_DF_3[\"y\"].values   \n",
    "        if isinstance(config[\"valid_loss\"], losses.relMSE):\n",
    "            config[\"valid_loss\"].y_train = Y_TRAIN_DF_3[\"y\"].values\n",
    "        model = model_class(**config)\n",
    "        fcst = NeuralForecast(models=[model], freq=FREQ)\n",
    "        fcst.fit(df=Y_TRAIN_DF_3, static_df=STATIC_3, val_size=24)\n",
    "        _ = fcst.predict(futr_df=Y_TEST_DF_3)\n",
    "\n",
    "        # DF_4\n",
    "        if model_class.MULTIVARIATE:\n",
    "            config[\"n_series\"] = N_SERIES_4\n",
    "        if isinstance(config[\"loss\"], losses.relMSE):\n",
    "            config[\"loss\"].y_train = Y_TRAIN_DF_4[\"y\"].values   \n",
    "        if isinstance(config[\"valid_loss\"], losses.relMSE):\n",
    "            config[\"valid_loss\"].y_train = Y_TRAIN_DF_4[\"y\"].values \n",
    "        model = model_class(**config)\n",
    "        fcst = NeuralForecast(models=[model], freq=FREQ)\n",
    "        fcst.fit(df=Y_TRAIN_DF_4, static_df=STATIC_4, val_size=24)\n",
    "        _ = fcst.predict(futr_df=Y_TEST_DF_4) \n",
    "\n",
    "# Tests a model against every loss function\n",
    "def check_loss_functions(model_class):\n",
    "    loss_list = [losses.MAE(), losses.MSE(), losses.RMSE(), losses.MAPE(), losses.SMAPE(), losses.MASE(seasonality=7), \n",
    "              losses.QuantileLoss(q=0.5), losses.MQLoss(), losses.IQLoss(), losses.HuberIQLoss(), losses.DistributionLoss(\"Normal\"), \n",
    "              losses.DistributionLoss(\"StudentT\"), losses.DistributionLoss(\"Poisson\"), losses.DistributionLoss(\"NegativeBinomial\"), \n",
    "              losses.DistributionLoss(\"Tweedie\", rho=1.5), losses.DistributionLoss(\"ISQF\"), losses.PMM(), losses.PMM(weighted=True), \n",
    "              losses.GMM(), losses.GMM(weighted=True), losses.NBMM(), losses.NBMM(weighted=True), losses.HuberLoss(), \n",
    "            losses.TukeyLoss(), losses.HuberQLoss(q=0.5), losses.HuberMQLoss()]\n",
    "    for loss in loss_list:\n",
    "        test_name = f\"{model_class.__name__}: checking {loss._get_name()}\"\n",
    "        print(f\"{test_name}\")\n",
    "        config = {'max_steps': 2,\n",
    "            'h': 7,\n",
    "            'input_size': 28,\n",
    "            'loss': loss,\n",
    "            'valid_loss': None,\n",
    "            'enable_progress_bar': False,\n",
    "            'enable_model_summary': False,\n",
    "            'val_check_steps': 2}        \n",
    "        try:\n",
    "            _run_model_tests(model_class, config) \n",
    "        except RuntimeError:\n",
    "            raise Exception(f\"{test_name} failed.\")\n",
    "        except Exception:\n",
    "            print(f\"{test_name} skipped on raised Exception.\")\n",
    "            pass\n",
    "\n",
    "# Tests a model against the AirPassengers dataset\n",
    "def check_airpassengers(model_class):\n",
    "    print(f\"{model_class.__name__}: checking forecast AirPassengers dataset\")\n",
    "    Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]] # 132 train\n",
    "    Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
    "\n",
    "    config = {'max_steps': 2,\n",
    "        'h': 12,\n",
    "        'input_size': 24,\n",
    "        'enable_progress_bar': False,\n",
    "        'enable_model_summary': False,\n",
    "        'val_check_steps': 2,\n",
    "        }\n",
    "\n",
    "    if model_class.MULTIVARIATE:\n",
    "        config[\"n_series\"] = Y_train_df[\"unique_id\"].nunique()\n",
    "    # Normal forecast\n",
    "    fcst = NeuralForecast(models=[model_class(**config)], freq='M')\n",
    "    fcst.fit(df=Y_train_df, static_df=AirPassengersStatic)\n",
    "    _ = fcst.predict(futr_df=Y_test_df)   \n",
    "\n",
    "    # Cross-validation\n",
    "    fcst = NeuralForecast(models=[model_class(**config)], freq='M')\n",
    "    _ = fcst.cross_validation(df=AirPassengersPanel, static_df=AirPassengersStatic, n_windows=2, step_size=12)\n",
    "\n",
    "# Add unit test functions to this function\n",
    "def check_model(model_class, checks=[\"losses\", \"airpassengers\"]):\n",
    "    \"\"\"\n",
    "    Check model with various tests. Options for checks are:<br>\n",
    "    \"losses\": test the model against all loss functions<br>\n",
    "    \"airpassengers\": test the model against the airpassengers dataset for forecasting and cross-validation<br>\n",
    "    \n",
    "    \"\"\"\n",
    "    if \"losses\" in checks:\n",
    "        check_loss_functions(model_class)   \n",
    "    if \"airpassengers\" in checks:\n",
    "        try:\n",
    "            check_airpassengers(model_class)   \n",
    "        except RuntimeError:\n",
    "            raise Exception(f\"{model_class.__name__}: AirPassengers forecast test failed.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "#| hide\n",
    "# Run tests in this file. This is a slow test\n",
    "import warnings\n",
    "import logging\n",
    "from neuralforecast.models import RNN, GRU, TCN, LSTM, DeepAR, DilatedRNN, BiTCN, MLP, NBEATS, NBEATSx, NHITS, DLinear, NLinear, TiDE, DeepNPTS, TFT, VanillaTransformer, Informer, Autoformer, FEDformer, TimesNet, iTransformer, KAN, RMoK, StemGNN, TSMixer, TSMixerx, MLPMultivariate, SOFTS, TimeMixer\n",
    "\n",
    "models = [RNN, GRU, TCN, LSTM, DeepAR, DilatedRNN, BiTCN, MLP, NBEATS, NBEATSx, NHITS, DLinear, NLinear, TiDE, DeepNPTS, TFT, VanillaTransformer, Informer, Autoformer, FEDformer, TimesNet, iTransformer, KAN, RMoK, StemGNN, TSMixer, TSMixerx, MLPMultivariate, SOFTS, TimeMixer]\n",
    "\n",
    "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)\n",
    "logging.getLogger(\"lightning_fabric\").setLevel(logging.ERROR)\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    for model in models:\n",
    "        check_model(model, checks=[\"losses\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
